from utils import *
import numpy as np
import sys
import json
from sklearn.metrics import pairwise
from tqdm import tqdm
import random
from sentence_transformers import SentenceTransformer
def nli_classification(args,model,tokenizer,text_a,text_b):
    """Return "yes" if the classifier predicts entailment for the pair, else "no".

    ``text_b`` is handed to the tokenizer as the first sequence and ``text_a``
    as the second (presumably premise and hypothesis respectively -- confirm
    against the model that is loaded).

    Args:
        args: namespace providing ``max_length_cot`` (truncation length).
        model: callable sequence-pair classifier; element 0 of its output is
            read as the logits, with one row per batch item.
        tokenizer: object exposing ``encode_plus`` that returns ``input_ids``,
            ``token_type_ids`` and ``attention_mask``.
        text_a: second sequence of the pair.
        text_b: first sequence of the pair.

    Returns:
        "yes" when class index 0 (entailment) has the highest probability,
        "no" for every other class index.
    """
    encoded_pair = tokenizer.encode_plus(text_b, text_a,
                                         max_length=args.max_length_cot,
                                         return_token_type_ids=True, truncation=True)
    # Build integer tensors directly (no float round-trip) and add a batch
    # dimension of size one.
    input_ids = torch.tensor(encoded_pair['input_ids'], dtype=torch.long).unsqueeze(0)
    token_type_ids = torch.tensor(encoded_pair['token_type_ids'], dtype=torch.long).unsqueeze(0)
    attention_mask = torch.tensor(encoded_pair['attention_mask'], dtype=torch.long).unsqueeze(0)

    # Pure inference: no autograd graph is needed for a single prediction.
    with torch.no_grad():
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=None)

    predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()  # batch_size only one
    predicted_class = int(np.argmax(predicted_probability))

    # Index 0 is "entailment". Any other index (contradiction, or neutral /
    # contradiction on a three-class head) counts as "not entailed", so a
    # three-class model no longer triggers a KeyError on index 2.
    return "yes" if predicted_class == 0 else "no"

def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, weighted by the attention mask.

    ``model_output[0]`` is taken as the token embeddings; ``attention_mask`` is
    broadcast over the last (embedding) dimension so masked-out positions
    contribute nothing to the sum. The divisor is clamped to ``1e-9`` so a
    fully masked sequence yields zeros instead of a division by zero.
    """
    embeddings = model_output[0]
    # Expand the mask so it has the same shape as the embeddings.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_sum = torch.sum(embeddings * mask, 1)
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return masked_sum / token_counts


def get_answer(args,q,context,flag='gpt'):
    """Ask for a step-by-step answer to ``q`` and return it with the extracted answer.

    Args:
        args: run configuration; reads ``model``, ``max_length_cot``,
            ``api_time_interval``, ``temperature``, ``direct_answer_trigger``
            and, for the non-gpt path, ``shot``, ``load_model``, ``tokenizer``.
        q: question text.
        context: text placed in front of the question (e.g. few-shot examples).
        flag: 'gpt' sends chat messages through ``GPT3_request``; any other
            value queries ``model_ask``.

    Returns:
        Tuple ``(answer_responses, pred_answer_responses)``: the reasoning text
        and the reply to ``args.direct_answer_trigger``. On the non-gpt path no
        second extraction request is made, so both elements are the same
        reasoning text.
    """
    prompt = """{context}Q: {q}\nA: Let's think step by step."""
    fill_prompt = prompt.format(context=context,q=q)
    if flag=='gpt':
        prompt_ask = [{"role": "user", "content": fill_prompt}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        answer_responses=responses.choices[0].message.content.replace('\n\n','\n')
        # Second turn: ask the model to state the final answer explicitly.
        prompt_ask = [
            {"role": "user", "content": fill_prompt},
            {"role": "assistant", "content": answer_responses},
            {"role": "user", "content": args.direct_answer_trigger}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        pred_answer_responses=responses.choices[0].message.content.replace('\n\n','\n')
    else:
        responses = model_ask(model=args.model, input_prompt=[fill_prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        # Zero-shot output is cut at the first newline, few-shot output at the
        # first blank line.
        separator = '\n' if args.shot==0 else '\n\n'
        answer_responses = responses.split(separator)[0]
        # Previously neither return value was assigned on this path, so the
        # function raised UnboundLocalError. Without a dedicated extraction
        # step the reasoning text doubles as the predicted answer.
        pred_answer_responses = answer_responses

    return answer_responses,pred_answer_responses

def reverse_get_question(args,answer,context,flag='gpt'):
    """Reconstruct the question that the given answer appears to be answering.

    Args:
        args: run configuration; reads ``dataset``, ``model``,
            ``max_length_cot``, ``api_time_interval``, ``temperature`` and,
            for the non-gpt path, ``load_model`` and ``tokenizer``.
        answer: answer text to reverse; for the "aqua" dataset its last line
            is dropped before prompting.
        context: text inserted between the instruction and the answer.
        flag: 'gpt' uses ``GPT3_request``; any other value uses ``model_ask``.

    Returns:
        The generated question on a single line, taken from whatever follows
        the last "Corresponding question: " marker in the model output.
    """
    instruction = "Give the concrete prompt (question) that can generate this answer. The question should contain all basic and necessary information and Corresponding to the answer. The question only can ask for one result."
    prompt = """{instruction}\n\n{context}\n\nAnswer: {answer}\nCorresponding question: """

    if args.dataset == "aqua":
        answer = '\n'.join(answer.split('\n')[:-1])
    fill_prompt = prompt.format(instruction = instruction , context = context, answer= answer)
    if flag=='gpt':
        prompt_ask = [{"role": "system", "content": "You are a reverse prompt expert."},
                    {"role": "user", "content": fill_prompt}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        reverse_question = responses.choices[0].message.content.replace('\n\n','\n')
    else:
        responses = model_ask(model=args.model, input_prompt=[fill_prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        # Keep only the first paragraph. The result was previously stored
        # under a different name, leaving ``reverse_question`` unassigned and
        # the function raising UnboundLocalError on this path.
        reverse_question = responses.split('\n\n')[0]

    generate_question = reverse_question.split('Corresponding question: ')[-1].replace('\n','')
    return generate_question

def segment_question(args,question,context,flag='gpt'):
    """Split a problem statement into its conditions and its final question.

    The model is asked to answer in a "Conditions: ... Question: ..." layout,
    and the reply is split on those two headers.

    Args:
        args: run configuration; reads ``dataset``, ``model``,
            ``max_length_cot``, ``api_time_interval``, ``temperature`` and,
            for the non-gpt path, ``load_model`` and ``tokenizer``.
        question: problem text; for the "aqua" dataset everything from
            "Answer Choices" onward is removed first.
        context: text inserted between the instruction and the problem text.
        flag: 'gpt' uses ``GPT3_request``; any other value uses ``model_ask``.

    Returns:
        Tuple ``(condition_list, question)``. ``condition_list`` holds the
        "- "-prefixed items found under "Conditions:" and is an empty list when
        that section is missing. ``question`` is the text after "Question:"
        with newlines removed, or "" when that section is missing.
    """
    instruction = """Please list the conditions and the question of the above text. There may be multiple conditions, but only one question.
Do not list conditions not related to calculations, but list all necessary conditions.
The format should be:
Conditions:
This is your output of conditions. Each line is one condition.
Question:
This is your output of the question."""

    if args.dataset == "aqua":
        question = question.split('Answer Choices')[0]

    prompt = """{instruction}\n\n{context}\n\nText: {text}"""
    fill_prompt = prompt.format(instruction = instruction, context = context, text = question)
    if flag == 'gpt':
        prompt_ask = [{"role": "user", "content": fill_prompt}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        condition_question = responses.choices[0].message.content.replace('\n\n','\n')
    else:
        responses = model_ask(model=args.model, input_prompt=[fill_prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        condition_question = responses.split('\n\n')[0]

    # sections[0] is any preamble, sections[1] the conditions and sections[2]
    # the question -- when the model followed the requested layout.
    sections = re.split("Conditions:|Question:",condition_question)

    # Explicit length checks replace the former bare ``except:`` clauses, which
    # could hide unrelated errors. A missing section now yields an empty list
    # (not an empty string), so the first return value is always a list.
    if len(sections) > 1:
        # Text in front of the first "- " bullet is discarded.
        condition_list = sections[1].replace('\n','').split('- ')[1:]
    else:
        condition_list = []

    question = sections[2] if len(sections) > 2 else ""
    return condition_list,question.replace('\n','')

def compare_question(args,ori_question,generate_question,flag='gpt'):
    """Judge whether the original and the regenerated question ask for the same thing.

    Args:
        args: run configuration; reads ``dataset``, ``model``,
            ``max_length_cot``, ``api_time_interval``, ``temperature``,
            ``load_model`` and ``tokenizer``.
        ori_question: original question; for the "aqua" dataset everything
            from "Answer Choices" onward is removed first.
        generate_question: question reconstructed from the model's answer.
        flag: 'gpt' uses ``GPT3_request``, 'nli' uses ``nli_classification``,
            any other value uses ``model_ask``.

    Returns:
        Dict with keys ``ori_q``, ``gen_q``, ``answer`` ("yes" or "no") and
        ``reason``.
    """
    prompt = """Q1: {q1}
Q2: {q2}

From a mathematical point of view, are these two math word problems ask the same thing at the end?
Please illustrate your reason and answer "yes" or "no"."""
    if args.dataset == "aqua":
        ori_question = ori_question.split('Answer Choices')[0]
    fill_prompt=prompt.format(q1 = ori_question, q2 = generate_question)
    result={}
    result['ori_q'] = ori_question
    result['gen_q'] = generate_question

    if flag == 'nli':
        compare_result = nli_classification(args,args.load_model,args.tokenizer,ori_question,generate_question)
        if compare_result == "yes":
            result["answer"] = "yes"
            result["reason"] = "They ask the same thing."
        else:
            result["answer"] = "no"
            result["reason"] = "Your answer is not consistent with the meaning of the question. The question asks {ori_q}".format(ori_q=ori_question)
        return result

    if flag == 'gpt':
        prompt_ask = [
                    {"role": "user", "content": fill_prompt}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        compare_result = responses.choices[0].message.content.replace('\n\n','\n')
    else:
        responses = model_ask(model=args.model, input_prompt=[fill_prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        # The output used to be stored in ``condition_question`` while the
        # code below read the unassigned ``compare_result``, so this path
        # always raised UnboundLocalError.
        compare_result = responses.split('\n\n')[0]

    # Both text-generation paths share one parser: the verdict is the leading
    # "yes"/"no", the reason is everything after the first sentence.
    result['answer'] = "yes" if compare_result.lower().startswith("yes") else "no"
    result['reason'] = '.'.join(compare_result.split('.')[1:]).replace('\n','')

    return result
        
def compare_condition_condition_list(args,condition,q_text,ori_to_gen=True,flag='gpt'):
    """Check whether ``condition`` can be deduced from ``q_text``.

    ``q_text`` is presented to the model either as a "condition list" or as a
    "context", depending on ``args.prompt``. For the "aqua" dataset everything
    from "Answer Choices" onward is removed from ``q_text`` first.

    Args:
        args: run configuration; reads ``prompt``, ``dataset``, ``model``,
            ``max_length_cot``, ``api_time_interval``, ``temperature``,
            ``load_model`` and ``tokenizer``.
        condition: candidate condition to verify.
        q_text: text the condition should follow from.
        ori_to_gen: selects the wording of the "no" reason on the 'gpt' and
            'nli' paths; it is not consulted on the ``model_ask`` path.
        flag: 'gpt' uses ``GPT3_request``, 'nli' uses ``nli_classification``,
            any other value uses ``model_ask``.

    Returns:
        Dict with keys ``condition``, ``answer`` ("yes" or "no") and ``reason``.
    """
    if args.prompt == "condition list":
        template = """Given a candidate condition: \"{condition}\"

Here is a condition list:\"{text}\"

From a mathematical point of view, can this candidate condition be deduced from the condition list?
Please illustrate your reason and answer "yes" or "no"."""
    else:
        template = """Given a candidate condition: \"{condition}\"

Here is a context:\"{text}\"

From a mathematical point of view, can this candidate condition be deduced from the context?
Please illustrate your reason and answer "yes" or "no"."""

    source_text = q_text.split('Answer Choices')[0] if args.dataset == "aqua" else q_text
    fill_prompt = template.format(condition = condition, text = source_text)
    result = {'condition': condition}

    # Classifier path: the reason is a canned sentence, no text is parsed.
    if flag == 'nli':
        entailed = nli_classification(args,args.load_model,args.tokenizer,condition,source_text)
        if entailed == "yes":
            result["answer"] = "yes"
            result["reason"] = "The condition is included in condition list."
            return result
        result["answer"] = "no"
        if ori_to_gen == True:
            result["reason"] = "You ignored a condition that {condition}.".format(condition=condition)
        else:
            result["reason"] = "The condition: {condition} does not match the known conditions of the question, you may have added this new condition, or changed the known conditions of the question.".format(condition=condition)
        return result

    # Text-generation paths: obtain the raw verdict text first.
    if flag == 'gpt':
        messages = [
                    {"role": "system", "content": "You are a reasoning expert"},
                    {"role": "user", "content": fill_prompt}]
        reply = GPT3_request(model=args.model, messages=messages, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        verdict_text = reply.choices[0].message.content.replace('\n\n','\n')
    else:
        reply = model_ask(model=args.model, input_prompt=[fill_prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        verdict_text = reply.split('\n\n')[0].replace('\n\n','\n')

    said_yes = verdict_text.lower().startswith("yes")
    result['answer'] = "yes" if said_yes else "no"

    # Everything after the first sentence of the reply is the explanation.
    explanation = '.'.join(verdict_text.split('.')[1:])

    # Only the gpt path swaps in the fixed "ignored condition" message, and
    # only for a "no" verdict when checking original -> generated.
    if flag == 'gpt' and not (ori_to_gen == False) and not said_yes:
        result['reason'] = "You have ignored a real condition: {condition} The question has mentioned it.".format(condition = condition)
    else:
        result['reason'] = explanation

    return result

def reflex(args,ignored_condition,inconsistent_condition,compare_question_result,ori_condition_list,generate_condition_list,ori_q,generate_question_text,previous_info,flag='gpt'):
    """Build a feedback prompt from the detected mistakes and request a revised answer.

    Args:
        args: run configuration; reads ``prompt``, ``model``,
            ``max_length_cot``, ``api_time_interval``, ``temperature``,
            ``direct_answer_trigger`` and, for the non-gpt path,
            ``load_model`` and ``tokenizer``.
        ignored_condition: sequence of dicts with ``condition`` and ``reason``
            for real conditions the answer overlooked.
        inconsistent_condition: sequence of dicts with ``condition`` and
            ``reason`` for conditions the answer used but that were not given.
        compare_question_result: dict with ``answer``, ``gen_q`` and ``reason``
            describing the question comparison.
        ori_condition_list: conditions of the original question. It is no
            longer modified in place.
        generate_condition_list: unused; kept for interface compatibility.
        ori_q: original question text.
        generate_question_text: unused; kept for interface compatibility.
        previous_info: indexable pair; element 0 is sent as the earlier user
            message and element 1 as the earlier assistant reply.
        flag: 'gpt' uses ``GPT3_request``; any other value uses ``model_ask``.

    Returns:
        Tuple ``(second_ans, fill_prompt, pred_second_ans)``: the revised
        reasoning, the feedback prompt that was sent, and the reply to
        ``args.direct_answer_trigger``. On the non-gpt path no extraction
        request is made, so ``pred_second_ans`` equals ``second_ans``.
    """
    use_condition_list = args.prompt=="condition list"
    revise_prompt=[]
    if use_condition_list:
        ignore_revise_prompt = """You have ignored some real conditions:\n{condition_temp}The real question has the condition list:
{condition_list}
You should consider all real conditions in the context.
Here are detailed illustrations:
{illustration_temp}"""

        inconsistent_revise_prompt = """You use some wrong candidate conditions:\n{condition_temp}They all can not be deduced from the true condition list.
The true question has the condition list:
{condition_list}
You should consider all real conditions in the context.
Here are detailed illustrations:
{illustration_temp}"""

        question_revise_prompt = """You misunderstood the question.
You think the question is "{q2}".
But the real question is "{q1}".
They are different. You should consider the real question.
Here is a detailed illustration:
{illustration_temp}"""

    else:
        ignore_revise_prompt = """You have ignored some real conditions:\n{condition_temp}The real question has the context:
{condition_list}
You should consider all real conditions in the context."""


        inconsistent_revise_prompt = """You use some wrong candidate conditions:\n{condition_temp}They all can not be deduced from the true context.
The true question has the context:
{condition_list}
You should consider all real conditions in the context."""

        question_revise_prompt = """You misunderstood the question.
The real question is "{q1}".
You should understand the real question."""

    # Number the conditions on a local copy. The earlier in-place rewrite
    # altered the caller's list, so a second call stacked prefixes
    # ("1: 1: ...").
    numbered_ori_conditions = [f"{i+1}: {o}" for i,o in enumerate(ori_condition_list)]

    # The templates reference either the numbered condition list or the raw
    # question text, depending on the prompt mode.
    if use_condition_list:
        reference_text = '\n'.join(numbered_ori_conditions)
    else:
        reference_text = ori_q

    # Order matters for the final prompt: inconsistent conditions first, then
    # ignored conditions, then the question mismatch.
    for findings, template in ((inconsistent_condition, inconsistent_revise_prompt),
                               (ignored_condition, ignore_revise_prompt)):
        if len(findings) == 0:
            continue
        condition_temp = ''.join(
            "{idx}: {condition}\n".format(idx=idx, condition=item['condition'])
            for idx, item in enumerate(findings, 1))
        illustration_temp = ''.join(
            "{idx}: {illustration}\n".format(idx=idx, illustration=item['reason'])
            for idx, item in enumerate(findings, 1))
        # Templates without an {illustration_temp} field ignore the extra
        # keyword argument.
        revise_prompt.append(template.format(condition_list=reference_text,
                                             condition_temp=condition_temp,
                                             illustration_temp=illustration_temp))

    if compare_question_result['answer'] != "yes":
        question_revise_fill_prompt = question_revise_prompt.format(
                                                                q1=ori_q,
                                                                q2=compare_question_result['gen_q'],
                                                                illustration_temp=compare_question_result['reason'])
        revise_prompt.append(question_revise_fill_prompt)

    # Second request: ask for a revised answer.
    instruction = "Here are the mistakes and reasons in your answer to the question."
    end_instruction = "You should double-check your answer, analyze the question more deeply, then think more carefully, and finally correct your answer step by step!"
    prompt = """{instruction}\n\n{mistakes}\n\n{end_instruction}"""
    fill_prompt = prompt.format(instruction=instruction,mistakes='\n\n'.join(revise_prompt),end_instruction =end_instruction)
    if flag == 'gpt':
        prompt_ask = [
            {"role": "system", "content": "You are a math expert. You are good at solving math question and correcting answers."},
            {"role": "user", "content": previous_info[0]},
            {"role": "assistant", "content": previous_info[1]},
            {"role": "user", "content": fill_prompt}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        second_ans = responses.choices[0].message.content.replace('\n\n','\n')
        # Follow-up turn: ask the model to state the final answer explicitly.
        prompt_ask = [
            {"role": "user", "content": previous_info[0]},
            {"role": "assistant", "content": previous_info[1]},
            {"role": "user", "content": fill_prompt},
            {"role": "assistant", "content": second_ans},
            {"role": "user", "content": args.direct_answer_trigger}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        pred_second_ans=responses.choices[0].message.content.replace('\n\n','\n')
    else:
        # NOTE(review): this path sends only the feedback prompt, without the
        # earlier question/answer from ``previous_info`` -- confirm intended.
        responses = model_ask(model=args.model, input_prompt=[fill_prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        second_ans = responses.split('\n\n')[0]
        # ``pred_second_ans`` was never assigned here, so the return raised
        # UnboundLocalError. Fall back to the revised reasoning itself.
        pred_second_ans = second_ans

    return second_ans,fill_prompt,pred_second_ans

def double_check(args, previous_info,flag='gpt'):
    """Ask the model to double-check its earlier answer, without specific feedback.

    Args:
        args: run configuration; reads ``model``, ``max_length_cot``,
            ``api_time_interval``, ``temperature``, ``direct_answer_trigger``
            and, for the non-gpt path, ``load_model`` and ``tokenizer``.
        previous_info: indexable pair; element 0 is sent as the earlier user
            message and element 1 as the earlier assistant reply.
        flag: 'gpt' uses ``GPT3_request``; any other value uses ``model_ask``.

    Returns:
        Tuple ``(double_check_ans, pred_double_check_ans)``: the re-checked
        reasoning and the reply to ``args.direct_answer_trigger``. On the
        non-gpt path no extraction request is made, so both elements are the
        same text.
    """
    prompt = """You should double-check your answer"""

    if flag == 'gpt':
        prompt_ask = [
            {"role": "user", "content": previous_info[0]},
            {"role": "assistant", "content": previous_info[1]},
            {"role": "user", "content": prompt}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        double_check_ans = responses.choices[0].message.content.replace('\n\n','\n')
        # Follow-up turn: ask the model to state the final answer explicitly.
        prompt_ask = [
            {"role": "user", "content": previous_info[0]},
            {"role": "assistant", "content": previous_info[1]},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": double_check_ans},
            {"role": "user", "content": args.direct_answer_trigger}]
        responses = GPT3_request(model=args.model, messages=prompt_ask, max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                        temperature=args.temperature, stop='\n', args = args)
        pred_double_check_ans = responses.choices[0].message.content.replace('\n\n','\n')
    else:
        # NOTE(review): this path sends only the bare instruction, without the
        # earlier question/answer from ``previous_info`` -- confirm intended.
        responses = model_ask(model=args.model, input_prompt=[prompt], max_tokens=args.max_length_cot, time_interval=args.api_time_interval,
                                      temperature=args.temperature, stop='\n',load_model=args.load_model,tokenizer=args.tokenizer,args=args)
        double_check_ans = responses.split('\n\n')[0]
        # ``pred_double_check_ans`` was never assigned here, so the return
        # raised UnboundLocalError. Fall back to the re-checked text itself.
        pred_double_check_ans = double_check_ans

    return double_check_ans, pred_double_check_ans
